{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import os\n",
    "os.sys.path.append('..')\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy.misc\n",
    "import warnings\n",
    "import sys\n",
    "import argparse\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "from torch.autograd import Variable\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import dataset_multi\n",
    "from darknet_multi import Darknet\n",
    "from utils_multi import *\n",
    "from cfg import parse_cfg\n",
    "from MeshPly import MeshPly\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.misc import imsave\n",
    "import scipy.io\n",
    "import scipy.misc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2019-10-18 17:00:04 Testing ape...\n",
      "2019-10-18 17:01:38    Acc using 5 px 2D Projection = 6.07%\n",
      "2019-10-18 17:01:38    Acc using 10 px 2D Projection = 39.32%\n",
      "2019-10-18 17:01:38    Acc using 15 px 2D Projection = 59.83%\n",
      "2019-10-18 17:01:38    Acc using 20 px 2D Projection = 68.29%\n",
      "2019-10-18 17:01:38    Acc using 25 px 2D Projection = 72.74%\n",
      "2019-10-18 17:01:38    Acc using 30 px 2D Projection = 74.96%\n",
      "2019-10-18 17:01:38    Acc using 35 px 2D Projection = 75.64%\n",
      "2019-10-18 17:01:38    Acc using 40 px 2D Projection = 76.32%\n",
      "2019-10-18 17:01:38    Acc using 45 px 2D Projection = 76.67%\n",
      "2019-10-18 17:01:38    Acc using 50 px 2D Projection = 78.03%\n",
      "2019-10-18 17:01:39 Testing can...\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-5-27dc0bb0bf4c>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m    146\u001b[0m \u001b[0mvalid\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdatacfg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodelcfg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweightfile\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    147\u001b[0m \u001b[0mdatacfg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m'cfg/can_occlusion.data'\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 148\u001b[1;33m \u001b[0mvalid\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdatacfg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodelcfg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweightfile\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    149\u001b[0m \u001b[0mdatacfg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m'cfg/cat_occlusion.data'\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    150\u001b[0m \u001b[0mvalid\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdatacfg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodelcfg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweightfile\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-5-27dc0bb0bf4c>\u001b[0m in \u001b[0;36mvalid\u001b[1;34m(datacfg, cfgfile, weightfile)\u001b[0m\n\u001b[0;32m     74\u001b[0m         \u001b[1;31m# Using confidence threshold, eliminate low-confidence predictions\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     75\u001b[0m         \u001b[0mtrgt\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtarget\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnum_labels\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 76\u001b[1;33m         \u001b[0mall_boxes\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mget_multi_region_boxes\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mconf_thresh\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnum_classes\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnum_keypoints\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0manchors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnum_anchors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrgt\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0monly_objectness\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     77\u001b[0m         \u001b[0mt4\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     78\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Documents\\Code\\singleshot6Dpose\\multi_obj_pose_estimation\\utils_multi.py\u001b[0m in \u001b[0;36mget_multi_region_boxes\u001b[1;34m(output, conf_thresh, num_classes, num_keypoints, anchors, num_anchors, correspondingclass, only_objectness, validation)\u001b[0m\n\u001b[0;32m    330\u001b[0m                         \u001b[0mmax_ind\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mind\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    331\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 332\u001b[1;33m                     \u001b[1;32mif\u001b[0m \u001b[0mconf\u001b[0m \u001b[1;33m>\u001b[0m \u001b[0mconf_thresh\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    333\u001b[0m                         \u001b[0mbcx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    334\u001b[0m                         \u001b[0mbcy\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "def valid(datacfg, cfgfile, weightfile):\n",
    "    \"\"\"Evaluate the multi-object pose network on one object's test split.\n",
    "\n",
    "    Runs the network over every image listed under `valid` in `datacfg`,\n",
    "    recovers [R|t] for ground truth and prediction with PnP, and logs the\n",
    "    2D reprojection accuracy at thresholds of 5 to 50 px.\n",
    "\n",
    "    Args:\n",
    "        datacfg: path to the object's .data file (keys read: valid, mesh,\n",
    "            name, im_width, im_height, fx, fy, u0, v0).\n",
    "        cfgfile: path to the Darknet network .cfg file.\n",
    "        weightfile: path to the trained weights for `cfgfile`.\n",
    "    \"\"\"\n",
    "    def truths_length(truths):\n",
    "        # Count the ground-truth rows before the first row whose column 1 (the\n",
    "        # first keypoint x) is 0; rows after it are presumably padding.\n",
    "        # Bounded by the real row count, so a fully populated tensor returns its\n",
    "        # length instead of None (which made range(num_gts) fail).\n",
    "        for i in range(len(truths)):\n",
    "            if truths[i][1] == 0:\n",
    "                return i\n",
    "        return len(truths)\n",
    "\n",
    "    # Parse data configuration file\n",
    "    data_options = read_data_cfg(datacfg)\n",
    "    valid_images = data_options['valid']\n",
    "    meshname     = data_options['mesh']\n",
    "    name         = data_options['name']\n",
    "    im_width     = int(data_options['im_width'])\n",
    "    im_height    = int(data_options['im_height'])\n",
    "    fx           = float(data_options['fx'])\n",
    "    fy           = float(data_options['fy'])\n",
    "    u0           = float(data_options['u0'])\n",
    "    v0           = float(data_options['v0'])\n",
    "\n",
    "    # Parse net configuration file once: the first section holds the net options,\n",
    "    # the last section holds the region/loss options\n",
    "    net_blocks    = parse_cfg(cfgfile)\n",
    "    net_options   = net_blocks[0]\n",
    "    loss_options  = net_blocks[-1]\n",
    "    conf_thresh   = float(net_options['conf_thresh'])\n",
    "    num_keypoints = int(net_options['num_keypoints'])\n",
    "    num_classes   = int(loss_options['classes'])\n",
    "    num_anchors   = int(loss_options['num'])\n",
    "    anchors       = [float(anchor) for anchor in loss_options['anchors'].split(',')]\n",
    "\n",
    "    # Read object model information, get 3D bounding box corners, get intrinsics\n",
    "    mesh                  = MeshPly(meshname)\n",
    "    vertices              = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()\n",
    "    corners3D             = get_3D_corners(vertices)\n",
    "    intrinsic_calibration = get_camera_intrinsic(u0, v0, fx, fy)  # camera params\n",
    "\n",
    "    # PnP inputs do not depend on the image, so build them once: the object-frame\n",
    "    # origin followed by the 3D box corners, and the intrinsic matrix as float32\n",
    "    objpoints3D = np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32')\n",
    "    K = np.array(intrinsic_calibration, dtype='float32')\n",
    "\n",
    "    # Network I/O params\n",
    "    num_labels = 2*num_keypoints+3  # +2 for width, height, +1 for object class\n",
    "    errs_2d = []  # mean 2D reprojection error (px) per ground-truth object\n",
    "\n",
    "    # Compute-related parameters\n",
    "    # NOTE(review): helpers from utils_multi may assume CUDA tensors; confirm before setting False\n",
    "    use_cuda = True\n",
    "    kwargs = {'num_workers': 4, 'pin_memory': True}\n",
    "\n",
    "    # Specify model, load pretrained weights, pass to GPU and set the module in evaluation mode\n",
    "    model = Darknet(cfgfile)\n",
    "    model.load_weights(weightfile)\n",
    "    if use_cuda:\n",
    "        model.cuda()\n",
    "    model.eval()\n",
    "\n",
    "    # Get the dataloader for the test dataset\n",
    "    valid_dataset = dataset_multi.listDataset(valid_images, shape=(model.width, model.height), shuffle=False, objclass=name, transform=transforms.Compose([transforms.ToTensor(),]))\n",
    "    test_loader   = torch.utils.data.DataLoader(valid_dataset, batch_size=1, shuffle=False, **kwargs)\n",
    "\n",
    "    # Iterate through test batches (batch size for test data is 1)\n",
    "    logging('Testing {}...'.format(name))\n",
    "    for data, target in test_loader:\n",
    "        if use_cuda:\n",
    "            data = data.cuda()\n",
    "\n",
    "        # Forward pass without recording an autograd graph. Variable(data, volatile=True)\n",
    "        # has no effect since PyTorch 0.4, and its warning is hidden by filterwarnings('ignore').\n",
    "        with torch.no_grad():\n",
    "            output = model(data).data\n",
    "\n",
    "        # Using confidence threshold, eliminate low-confidence predictions; the class id\n",
    "        # of the first ground-truth row is passed as the corresponding class\n",
    "        trgt = target[0].view(-1, num_labels)\n",
    "        all_boxes = get_multi_region_boxes(output, conf_thresh, num_classes, num_keypoints, anchors, num_anchors, int(trgt[0][0]), only_objectness=0)\n",
    "\n",
    "        # Iterate through all images in the batch\n",
    "        for i in range(output.size(0)):\n",
    "            # For each image, get all the predictions\n",
    "            boxes = all_boxes[i]\n",
    "\n",
    "            # For each image, get all the targets (there might be more than 1 target per image)\n",
    "            truths = target[i].view(-1, num_labels)\n",
    "\n",
    "            # Iterate through each ground-truth object present in the scene\n",
    "            for k in range(truths_length(truths)):\n",
    "                gt_class = int(truths[k][0])\n",
    "                gt_keypoints = truths[k][1:2*num_keypoints+1].numpy()\n",
    "\n",
    "                # Among predictions of this object's class, keep the most confident one.\n",
    "                # Reset per object so an earlier object's prediction is never reused.\n",
    "                box_pr = None\n",
    "                best_conf_est = -sys.maxsize\n",
    "                for box in boxes:\n",
    "                    if (box[2*num_keypoints] > best_conf_est) and (box[2*num_keypoints+2] == gt_class):\n",
    "                        best_conf_est = box[2*num_keypoints]\n",
    "                        box_pr = box\n",
    "                if box_pr is None:\n",
    "                    # No prediction of this class: count a failure at every pixel threshold\n",
    "                    errs_2d.append(np.inf)\n",
    "                    continue\n",
    "\n",
    "                # Denormalize the ground-truth and predicted keypoints to pixels\n",
    "                corners2D_gt = np.array(np.reshape(gt_keypoints, [-1, 2]), dtype='float32')\n",
    "                corners2D_pr = np.array(np.reshape(box_pr[:2*num_keypoints], [-1, 2]), dtype='float32')\n",
    "                corners2D_gt[:, 0] *= im_width\n",
    "                corners2D_gt[:, 1] *= im_height\n",
    "                corners2D_pr[:, 0] *= im_width\n",
    "                corners2D_pr[:, 1] *= im_height\n",
    "                corners2D_gt_corrected = fix_corner_order(corners2D_gt)  # Fix the order of corners\n",
    "\n",
    "                # Compute [R|t] by pnp\n",
    "                R_gt, t_gt = pnp(objpoints3D, corners2D_gt_corrected, K)\n",
    "                R_pr, t_pr = pnp(objpoints3D, corners2D_pr, K)\n",
    "\n",
    "                # Pixel error: mean distance between mesh vertices projected with each pose\n",
    "                Rt_gt        = np.concatenate((R_gt, t_gt), axis=1)\n",
    "                Rt_pr        = np.concatenate((R_pr, t_pr), axis=1)\n",
    "                proj_2d_gt   = compute_projection(vertices, Rt_gt, intrinsic_calibration)\n",
    "                proj_2d_pred = compute_projection(vertices, Rt_pr, intrinsic_calibration)\n",
    "                norm         = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)\n",
    "                errs_2d.append(np.mean(norm))\n",
    "\n",
    "    # Compute 2D projection score (eps avoids division by zero when nothing was evaluated)\n",
    "    eps = 1e-5\n",
    "    errs_2d = np.array(errs_2d)\n",
    "    for px_threshold in [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]:\n",
    "        acc = np.count_nonzero(errs_2d <= px_threshold) * 100. / (len(errs_2d)+eps)\n",
    "        # Print test statistics\n",
    "        logging('   Acc using {} px 2D Projection = {:.2f}%'.format(px_threshold, acc))\n",
    "\n",
    "modelcfg = 'cfg/yolo-pose-multi.cfg'\n",
    "weightfile = 'backup_multi/model_backup.weights'\n",
    "\n",
    "# Evaluate every object with its own data config; the network and weights are shared\n",
    "for obj_name in ['ape', 'can', 'cat', 'duck', 'glue', 'holepuncher']:\n",
    "    datacfg = 'cfg/{}_occlusion.data'.format(obj_name)\n",
    "    valid(datacfg, modelcfg, weightfile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
